from qubiter.SEO_writer import *
from qubiter.adv_applications.StairsDerivCkt_writer import *
import collections as col


class StairsDerivThrCkt_writer(SEO_writer):
    """
    (Stairs Derivative Threaded Circuit) This class is a subclass of
    `SEO_writer`. It writes English and Picture files for several derivative
    circuits used for calculating the gradients of a quantum cost function
    (mean hamiltonian).

    This class calls class `StairsDerivCkt_writer` many times. Each time,
    a new sub-circuit is built that acts on a fresh set of qubits not used
    before. So this class builds many threads, i.e., independent
    sub-circuits which can be evolved in parallel.

    If you compare the constructor of this "threaded" writer class with that
    of the "non-threaded" writer class `StairsDerivCkt_writer`, you will see
    that the constructor of this class has two fewer arguments, namely
    `deriv_gate_str`, `has_neg_polarity`. All other arguments are the same for
    both classes. The reason for this is that all the possibilities for
    those two arguments are included as (parallel, independent) subcircuits
    of the giant quantum circuit created by this class.

    Most attributes of this class are the same as those for
    `StairsDerivCkt_writer`. One new attribute is
    `subckt_to_large_small_bit_pair`.

    subckts are labeled by a pair: (deriv_gate_str, has_neg_polarity). Every
    deriv_gate_str other than 'prior' gives two subckts, with
    has_neg_polarity equal to False and True. 'prior' gives a single subckt,
    labeled ('prior', None).

    `subckt_to_large_small_bit_pair` maps each subckt to its edge qubits,
    stored as the pair (large_bit, small_bit). The subckt acts on qubits
    small_bit, small_bit+1, ..., large_bit-1, i.e., large_bit is one past the
    last qubit used by that subckt.

    Attributes
    ----------
    deriv_direc : int
        in range(4)
    dpart_name : str
    gate_str_to_rads_list : dict[str, list[float|str]]
        maps each gate string (possibly including 'prior') to its list of
        rads; entries may be '#int' placeholder strings
    parent_num_qbits : int
    subckt_to_large_small_bit_pair : dict[(str, bool|None), tuple[int, int]]

    """

    def __init__(self, deriv_direc, dpart_name,
                 gate_str_to_rads_list,
                 file_prefix, emb, **kwargs):
        """
        Constructor

        This constructor writes English and Picture files but it doesn't
        close those files after writing them. You must do that yourself
        using close_files().

        The arguments are checked before SEO_writer.__init__ is called.

        Parameters
        ----------
        deriv_direc : int
            must be in range(4)
        dpart_name : str
        gate_str_to_rads_list : dict[str, list[float|str]]
            must be non-empty
        file_prefix : str
        emb : CktEmbedder
        kwargs : dict
            key-word arguments of SEO_writer

        Raises
        ------
        ValueError
            if deriv_direc is not in range(4) or if gate_str_to_rads_list is
            empty

        Returns
        -------

        """
        # Validate before SEO_writer.__init__ runs. NOTE(review): that
        # constructor presumably opens the English/Picture outputs that
        # close_files() later closes, so failing early avoids leaving them
        # open.
        if deriv_direc not in range(4):
            raise ValueError(
                'deriv_direc must be in range(4), got ' + str(deriv_direc))
        if not gate_str_to_rads_list:
            raise ValueError('gate_str_to_rads_list must not be empty')

        SEO_writer.__init__(self, file_prefix, emb, **kwargs)
        self.deriv_direc = deriv_direc
        self.dpart_name = dpart_name
        self.gate_str_to_rads_list = gate_str_to_rads_list
        self.subckt_to_large_small_bit_pair = col.OrderedDict()

        # The number of parent qbits is inferred from the length of the last
        # gate string. NOTE(review): this assumes the last key has the most
        # controls and that every control adds 2 characters to the gate
        # string -- confirm against StairsCkt_writer.
        last_key = tuple(self.gate_str_to_rads_list.keys())[-1]
        self.parent_num_qbits = 1 + len(last_key)//2

        self.write()

    def write(self):
        """
        This method writes English and Picture files for a giant quantum
        circuit which consists of many parallel, independent subcircuits,
        each subcircuit acting on a distinct set of qbits. The subcircuits
        are generated by calling `StairsDerivCkt_writer` for all the
        possibilities of `deriv_gate_str` and `has_neg_polarity`.

        It also (re)builds `self.subckt_to_large_small_bit_pair`.

        Returns
        -------

        """
        tot_num_qbits = self.get_tot_num_qbits()
        num_qbits_per_subckt = self.parent_num_qbits + 1

        # Give each subckt a contiguous, non-overlapping block of
        # num_qbits_per_subckt qbits, in the iteration order of
        # gate_str_to_rads_list. Start from an empty map so that calling
        # write() again does not keep stale entries.
        self.subckt_to_large_small_bit_pair.clear()
        cum_nbits = 0
        for deriv_gate_str in self.gate_str_to_rads_list:
            # 'prior' has no polarity; every other gate has both
            if deriv_gate_str == 'prior':
                polarities = [None]
            else:
                polarities = [False, True]
            for has_neg_polarity in polarities:
                small_bit = cum_nbits
                cum_nbits += num_qbits_per_subckt
                large_bit = cum_nbits
                self.subckt_to_large_small_bit_pair[
                    (deriv_gate_str, has_neg_polarity)] = \
                    (large_bit, small_bit)

        assert tot_num_qbits == cum_nbits,\
            str(tot_num_qbits) + ' and ' + str(cum_nbits) + ' should be equal'

        for subckt, (_, small_bit) in \
                self.subckt_to_large_small_bit_pair.items():
            deriv_gate_str, has_neg_polarity = subckt
            # map qbits 0..num_qbits_per_subckt-1 of the subckt onto
            # small_bit..large_bit-1 of the full circuit
            bit_map = [k + small_bit for k in range(num_qbits_per_subckt)]
            pre_emb = CktEmbedder(num_qbits_per_subckt, tot_num_qbits,
                                  bit_map)
            compo_emb = CktEmbedder.composition(self.emb, pre_emb)
            # Only the side effect of the constructor is needed: it writes
            # into this writer's english_out/picture_out streams.
            StairsDerivCkt_writer(
                deriv_gate_str, has_neg_polarity, self.deriv_direc,
                self.dpart_name,
                self.gate_str_to_rads_list,
                self.file_prefix, compo_emb,
                english_out=self.english_out, picture_out=self.picture_out)

    def get_coef_of_dpart(self, subckt,
            deriv_direc, dpart_name, var_num_to_rads=None):
        """
        This method returns the coefficient of dpart (either p1, ps or
        -p1*ps) for the single subckt `subckt`. A subckt is labeled by the
        pair (deriv_gate_str, has_neg_polarity). Only deriv_gate_str is used
        here, to look up that gate's rads list.

        var_num_to_rads is used if self wrote the English file with #int
        string symbols for placeholder variables. var_num_to_rads is used to
        float those strings. This is necessary before the output of this
        method can be calculated analytically.

        Parameters
        ----------
        subckt : (str, bool|None)
        deriv_direc : int
        dpart_name : str
        var_num_to_rads : dict[int, float]

        Returns
        -------
        float
            NOTE(review): the exact type is whatever
            StairsDerivCkt_writer.get_coef_of_dpart returns

        """
        deriv_gate_str = subckt[0]
        t_list = self.gate_str_to_rads_list[deriv_gate_str]
        return StairsDerivCkt_writer.get_coef_of_dpart(
            t_list, deriv_direc, dpart_name, var_num_to_rads)

    def get_fun_name_to_fun(self, deriv_direc, dpart_name):
        """
        This method returns a dictionary fun_name_to_fun mapping the
        function name to function, for all functions defined by this class.
        It combines the fun_name_to_fun of all the subckts.

        Parameters
        ----------
        deriv_direc : int
        dpart_name : str

        Returns
        -------
        dict[str, function]

        """
        # Both polarities of a gate share the same t_list, so visit each
        # deriv_gate_str only once, in first-seen order.
        deriv_gate_strs = col.OrderedDict.fromkeys(
            subckt[0] for subckt in self.subckt_to_large_small_bit_pair)
        fun_name_to_fun = {}
        for deriv_gate_str in deriv_gate_strs:
            t_list = self.gate_str_to_rads_list[deriv_gate_str]
            fun_name_to_fun.update(
                StairsDerivCkt_writer.get_fun_name_to_fun(
                    t_list, deriv_direc, dpart_name))
        return fun_name_to_fun

    @staticmethod
    def sta_get_tot_num_qbits(parent_num_qbits, gate_str_to_rads_list):
        """
        This static (sta) method returns the total number of qbits for the
        quantum circuit generated by this class, i.e., it returns the sum of
        the qbits used by each subcircuit. Each subcircuit uses
        parent_num_qbits+1 qbits.

        Parameters
        ----------
        parent_num_qbits : int
        gate_str_to_rads_list : dict[str, list[float|str]]

        Returns
        -------
        int

        """
        num_subckts = len(gate_str_to_rads_list)*2
        if 'prior' in gate_str_to_rads_list:
            # subtract one because derivative of prior has has_neg_polarity
            # in [None] instead of in [False, True]
            num_subckts -= 1
        return num_subckts*(parent_num_qbits+1)

    def get_tot_num_qbits(self):
        """
        This is the self version of sta_get_tot_num_qbits().

        Returns
        -------
        int

        """
        return StairsDerivThrCkt_writer.\
            sta_get_tot_num_qbits(self.parent_num_qbits,
                                  self.gate_str_to_rads_list)


if __name__ == "__main__":
    def main():
        """
        Demo: write the threaded derivative circuit of a 3-qbit stairs
        circuit, once for each (deriv_direc, dpart_name) pair below, and
        print the resulting English and Picture files.
        """
        num_parent_bits = 3

        # (alternative setting left by the author:
        # u2_bit_to_higher_bits = None)
        higher_bits = {0: [2], 1: [2], 2: []}
        rads_lists = StairsCkt_writer.get_gate_str_to_rads_list(
            num_parent_bits, '#int',
            rads_const=np.pi/2,
            u2_bit_to_higher_bits=higher_bits)

        prefix = 'stairs_deriv_thr_writer_test'

        num_bits = StairsDerivThrCkt_writer.sta_get_tot_num_qbits(
            num_parent_bits, rads_lists)
        print("tot_num_qbits=", num_bits)
        full_emb = CktEmbedder(num_bits, num_bits)

        direc_and_part_pairs = ((0, 'single'), (3, 's'))
        for direc, part_name in direc_and_part_pairs:
            writer = StairsDerivThrCkt_writer(
                deriv_direc=direc,
                dpart_name=part_name,
                gate_str_to_rads_list=rads_lists,
                file_prefix=prefix,
                emb=full_emb)
            writer.close_files()
            print("%%%%%%%%%%%%%%%%%%%%%%%%%%")
            writer.print_eng_file()
            writer.print_pic_file()

    main()
